{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp data.transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Time Series Data Augmentation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">Functions used to transform TSTensors (Data Augmentation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from tsai.imports import *\n",
    "from scipy.interpolate import CubicSpline\n",
    "from scipy.ndimage import convolve1d\n",
    "from fastcore.transform import compose_tfms\n",
    "from fastai.vision.augment import RandTransform\n",
    "from tsai.utils import *\n",
    "from tsai.data.core import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsai.data.core import TSCategorize\n",
    "from tsai.data.external import get_UCR_data\n",
    "from tsai.data.preprocessing import TSStandardize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dsid = 'NATOPS'\n",
    "X, y, splits = get_UCR_data(dsid, return_split=False)\n",
    "tfms = [None, TSCategorize()]\n",
    "batch_tfms = TSStandardize()\n",
    "dls = get_ts_dls(X, y, tfms=tfms, splits=splits, batch_tfms=batch_tfms, bs=128)\n",
    "xb, yb = next(iter(dls.train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSIdentity(RandTransform):\n",
    "    \"\"\"Pass-through transform: returns every `TSTensor` batch exactly as received.\n",
    "\n",
    "    `magnitude` is stored on the instance but never read by this class.\n",
    "    \"\"\"\n",
    "    order = 90\n",
    "\n",
    "    def __init__(self, magnitude=None, **kwargs):\n",
    "        self.magnitude = magnitude\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def encodes(self, o: TSTensor):\n",
    "        return o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSIdentity()(xb, split_idx=0).shape, xb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# partial(TSShuffle_HLs, ex=0),\n",
    "class TSShuffle_HLs(RandTransform):\n",
    "    \"\"\"Randomly swaps the high/low values of an OHLC `TSTensor` batch.\n",
    "\n",
    "    The last axis is read as consecutive groups of 4 steps. For a random, non-empty subset of those\n",
    "    groups the values at offsets 1 and 2 inside the group (highs and lows) are exchanged.\n",
    "\n",
    "    Args:\n",
    "        magnitude: on/off switch only; a falsy or non-positive value returns the input unchanged.\n",
    "        ex: optional index on the second-to-last axis whose values are copied back from the input.\n",
    "    \"\"\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=1., ex=None, **kwargs):\n",
    "        self.magnitude, self.ex = magnitude, ex\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        timesteps = o.shape[-1] // 4\n",
    "        # Fewer than 4 steps: there is no complete group to swap and random.randint(1, 0) would raise\n",
    "        if timesteps < 1: return o\n",
    "        pos_rand_list = random_choice(np.arange(timesteps),size=random.randint(1, timesteps),replace=False)\n",
    "        rand_list = pos_rand_list * 4\n",
    "        highs = rand_list + 1\n",
    "        lows = highs + 1\n",
    "        # Column-major flatten interleaves the rows: a = [h0, l0, h1, l1, ...], b = [l0, h0, l1, h1, ...]\n",
    "        a = np.vstack([highs, lows]).flatten('F')\n",
    "        b = np.vstack([lows, highs]).flatten('F')\n",
    "        output = o.clone()\n",
    "        # Indexing with an index array on the right-hand side yields a copy, so this is a true swap\n",
    "        output[...,a] = output[...,b]\n",
    "        if self.ex is not None: output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSShuffle_HLs()(xb, split_idx=0).shape, xb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# partial(TSShuffleSteps, ex=0),\n",
    "class TSShuffleSteps(RandTransform):\n",
    "    \"\"\"Randomly swaps pairs of consecutive steps along the last axis of a `TSTensor` batch.\n",
    "\n",
    "    On every call the pairs are aligned either on steps (0,1), (2,3), ... or on (1,2), (3,4), ...\n",
    "    (picked at random), and a random non-empty subset of those pairs is swapped.\n",
    "\n",
    "    Args:\n",
    "        magnitude: on/off switch only; a falsy or non-positive value returns the input unchanged.\n",
    "        ex: optional index on the second-to-last axis whose values are copied back from the input.\n",
    "    \"\"\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=1., ex=None, **kwargs):\n",
    "        self.magnitude, self.ex = magnitude, ex\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        seq_len = o.shape[-1]\n",
    "        is_even = 1 - seq_len % 2\n",
    "        r = np.random.randint(2)  # 0: pairs start at step 0, 1: pairs start at step 1\n",
    "        # An even length offers one pair less when the pairs are shifted by one step\n",
    "        n_pairs = seq_len // 2 - r * is_even\n",
    "        # Too short to hold a single pair: nothing to swap (random.randint(1, 0) would raise)\n",
    "        if n_pairs < 1: return o\n",
    "        # Index of the second element of every selected pair\n",
    "        pos_rand_list = random_choice(np.arange(0, n_pairs), size=random.randint(1, n_pairs),replace=False) * 2 + 1 + r\n",
    "        a = np.vstack([pos_rand_list, pos_rand_list - 1]).flatten('F')\n",
    "        b = np.vstack([pos_rand_list - 1, pos_rand_list]).flatten('F')\n",
    "        output = o.clone()\n",
    "        output[...,a] = output[...,b]\n",
    "        if self.ex is not None: output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGdCAYAAAA44ojeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAiPElEQVR4nO3dfXST9f3/8VdvaFpu0tpqEzraWjc2qNwplRJhm0JHxcqR0aPiqdghB85YikDPELtxt6JUmRMGKyAeBnhGh/IHKAyRUrQcR1ugjB0Ehzj5nnZi2m3YBrpDWtr8/tiPbBFQA9F80j0f51znkOv6pHlfOUCeJ73SRni9Xq8AAAAMEhnqAQAAAD6LQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgnOhQD3A9urq6dPbsWfXp00cRERGhHgcAAHwJXq9X58+fV0pKiiIjP/89krAMlLNnzyo1NTXUYwAAgOvQ2Niofv36fe6asAyUPn36SPr3CVqt1hBPAwAAvgy3263U1FTf6/jnCctAufxtHavVSqAAABBmvszlGVwkCwAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA40SHegB0b7c+/YdQjyBJ+r/n8kI9AgAgAARKmOKFHwDQnQX8LZ6PP/5Yjz32mJKSkhQXF6fBgwfryJEjvuNer1eLFi1S3759FRcXp5ycHJ0+fdrva5w7d04FBQWyWq1KSEjQtGnTdOHChRs/GwAA0C0E9A7Kp59+qlGjRunee+/Vm2++qVtuuUWnT5/WTTfd5FuzfPlyrVq1Sps3b1ZGRoYWLlyo3NxcnTx5UrGxsZKkgoICffLJJ6qsrFRHR4emTp2qGTNmqKKiIrhnBwDA/8c7z+EloEB5/vnnlZqaqo0bN/r2ZWRk+P7s9Xq1cuVKLViwQA8++KAk6ZVXXpHNZtOOHTs0efJkvf/++9qzZ48OHz6srKwsSdLq1at1//3364UXXlBKSkowzgsIiAn/cfGfFgD8R0CB8sYbbyg3N1cPPfSQqqur9Y1vfEM/+clPNH36dEnSmTNn5HK5lJOT47tPfHy8srOzVVNTo8mTJ6umpkYJCQm+OJGknJwcRUZGqq6uTj/84Q+veFyPxyOPx+O77Xa7Az5RAPhvJkSpRJgC1xLQNSgfffSR1q5dq/79++utt97SzJkz9eSTT2rz5s2SJJfLJUmy2Wx+97PZbL5jLpdLycnJfsejo6OVmJjoW/NZZWVlio+P922pqamBjA0AAMJMQIHS1dWlO++8U8uWLdMdd9yhGTNmaPr06Vq3bt1XNZ8kqaSkRK2trb6tsbHxK308AAAQWgEFSt++fZWZmem3b+DAgWpoaJAk2e12SVJTU5PfmqamJt8xu92u5uZmv+OXLl3SuXPnfGs+y2KxyGq1+m0AAKD7CihQRo0apVOnTvnt++CDD5Seni7p3xfM2u12VVVV+Y673W7V1dXJ4XBIkhwOh1paWlRfX+9bs3//fnV1dSk7O/u6TwQAAHQfAV0kO3fuXN19991atmyZHn74YR06dEjr16/X+vXrJUkRERGaM2eOnnnmGfXv39/3MeOUlBRNnDhR0r/fcbnvvvt83xrq6OhQUVGRJk+ezCd4AOAzTLiYlwt5EQoBBcpdd92l7du3q6SkRKWlpcrIyNDKlStVUFDgW/PUU0+pra1NM2bMUEtLi0aPHq09e/b4fgaKJG3ZskVFRUUaO3asIiMjlZ+fr1WrVgXvrAAAXytCCsEW8I+6f+CBB/TAAw9c83hERIRKS0tVWlp6zTWJiYn8UDYAAHBN/DZjAABgHH5ZIICgMuGtfom3+4FwR6AAYcSEF39e+AF8HfgWDwAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AA
AADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4AQXKkiVLFBER4bcNGDDAd/zixYtyOp1KSkpS7969lZ+fr6amJr+v0dDQoLy8PPXs2VPJycmaN2+eLl26FJyzAQAA3UJ0oHe4/fbbtW/fvv98gej/fIm5c+fqD3/4g7Zt26b4+HgVFRVp0qRJ+uMf/yhJ6uzsVF5enux2uw4ePKhPPvlEjz/+uHr06KFly5YF4XQAAEB3EHCgREdHy263X7G/tbVVGzZsUEVFhcaMGSNJ2rhxowYOHKja2lqNHDlSe/fu1cmTJ7Vv3z7ZbDYNGzZMS5cu1fz587VkyRLFxMTc+BkBAICwF/A1KKdPn1ZKSopuu+02FRQUqKGhQZJUX1+vjo4O5eTk+NYOGDBAaWlpqqmpkSTV1NRo8ODBstlsvjW5ublyu906ceLENR/T4/HI7Xb7bQAAoPsKKFCys7O1adMm7dmzR2vXrtWZM2f03e9+V+fPn5fL5VJMTIwSEhL87mOz2eRyuSRJLpfLL04uH7987FrKysoUHx/v21JTUwMZGwAAhJmAvsUzfvx435+HDBmi7Oxspaen67XXXlNcXFzQh7uspKRExcXFvttut5tIAQCgG7uhjxknJCTo29/+tj788EPZ7Xa1t7erpaXFb01TU5PvmhW73X7Fp3ou377adS2XWSwWWa1Wvw0AAHRfNxQoFy5c0F//+lf17dtXw4cPV48ePVRVVeU7furUKTU0NMjhcEiSHA6Hjh8/rubmZt+ayspKWa1WZWZm3sgoAACgGwnoWzw//elPNWHCBKWnp+vs2bNavHixoqKi9Oijjyo+Pl7Tpk1TcXGxEhMTZbVaNWvWLDkcDo0cOVKSNG7cOGVmZmrKlClavny5XC6XFixYIKfTKYvF8pWcIAAACD8BBcrf/vY3Pfroo/rnP/+pW265RaNHj1Ztba1uueUWSdKKFSsUGRmp/Px8eTwe5ebmas2aNb77R0VFadeuXZo5c6YcDod69eqlwsJClZaWBvesAABAWAsoULZu3fq5x2NjY1VeXq7y8vJrrklPT9fu3bsDeVgAAPA/ht/FAwAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOPcUKA899xzioiI0Jw5c3z7Ll68KKfTqaSkJPXu3Vv5+flqamryu19DQ4Py8vLUs2dPJScna968ebp06dKNjAIAALqR6w6Uw4cP66WXXtKQIUP89s+dO1c7d+7Utm3bVF1drbNnz2rSpEm+452dncrLy1N7e7sOHjyozZs3a9OmTVq0aNH1nwUAAOhWritQLly4oIKCAr388su66aabfPtbW1u1YcMGvfjiixozZoyGDx+ujRs36uDBg6qtrZUk7d27VydPntTvfvc7DRs2TOPHj9fSpUtVXl6u9vb24JwVAAAIa9cVKE6nU3l5ecrJ
yfHbX19fr46ODr/9AwYMUFpammpqaiRJNTU1Gjx4sGw2m29Nbm6u3G63Tpw4cdXH83g8crvdfhsAAOi+ogO9w9atW3X06FEdPnz4imMul0sxMTFKSEjw22+z2eRyuXxr/jtOLh+/fOxqysrK9Itf/CLQUQEAQJgK6B2UxsZGzZ49W1u2bFFsbOxXNdMVSkpK1Nra6tsaGxu/tscGAABfv4ACpb6+Xs3NzbrzzjsVHR2t6OhoVVdXa9WqVYqOjpbNZlN7e7taWlr87tfU1CS73S5JstvtV3yq5/Lty2s+y2KxyGq1+m0AAKD7CihQxo4dq+PHj+vYsWO+LSsrSwUFBb4/9+jRQ1VVVb77nDp1Sg0NDXI4HJIkh8Oh48ePq7m52bemsrJSVqtVmZmZQTotAAAQzgK6BqVPnz4aNGiQ375evXopKSnJt3/atGkqLi5WYmKirFarZs2aJYfDoZEjR0qSxo0bp8zMTE2ZMkXLly+Xy+XSggUL5HQ6ZbFYgnRaAAAgnAV8kewXWbFihSIjI5Wfny+Px6Pc3FytWbPGdzwqKkq7du3SzJkz5XA41KtXLxUWFqq0tDTYowAAgDB1w4Hyzjvv+N2OjY1VeXm5ysvLr3mf9PR07d69+0YfGgAAdFP8Lh4AAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcaJDPYCJbn36D6EeQf/3XF6oRwAAIGR4BwUAABiHQAEAAMYJKFDWrl2rIUOGyGq1ymq1yuFw6M033/Qdv3jxopxOp5KSktS7d2/l5+erqanJ72s0NDQoLy9PPXv2VHJysubNm6dLly4F52wAAEC3EFCg9OvXT88995zq6+t15MgRjRkzRg8++KBOnDghSZo7d6527typbdu2qbq6WmfPntWkSZN89+/s7FReXp7a29t18OBBbd68WZs2bdKiRYuCe1YAACCsBXSR7IQJE/xuP/vss1q7dq1qa2vVr18/bdiwQRUVFRozZowkaePGjRo4cKBqa2s1cuRI7d27VydPntS+fftks9k0bNgwLV26VPPnz9eSJUsUExMTvDMDAABh67qvQens7NTWrVvV1tYmh8Oh+vp6dXR0KCcnx7dmwIABSktLU01NjSSppqZGgwcPls1m863Jzc2V2+32vQtzNR6PR263228DAADdV8CBcvz4cfXu3VsWi0U//vGPtX37dmVmZsrlcikmJkYJCQl+6202m1wulyTJ5XL5xcnl45ePXUtZWZni4+N9W2pqaqBjAwCAMBJwoHznO9/RsWPHVFdXp5kzZ6qwsFAnT578KmbzKSkpUWtrq29rbGz8Sh8PAACEVsA/qC0mJkbf+ta3JEnDhw/X4cOH9etf/1qPPPKI2tvb1dLS4vcuSlNTk+x2uyTJbrfr0KFDfl/v8qd8Lq+5GovFIovFEuioAACEHRN+WKgU+h8YesM/B6Wrq0sej0fDhw9Xjx49VFVV5Tt26tQpNTQ0yOFwSJIcDoeOHz+u5uZm35rKykpZrVZlZmbe6CgAAKCbCOgdlJKSEo0fP15paWk6f/68Kioq9M477+itt95SfHy8pk2bpuLiYiUmJspqtWrWrFlyOBwaOXKkJGncuHHKzMzUlClTtHz5crlcLi1YsEBOp5N3SAAAgE9AgdLc3KzHH39cn3zyieLj4zVkyBC99dZb+sEPfiBJWrFihSIjI5Wfny+Px6Pc
3FytWbPGd/+oqCjt2rVLM2fOlMPhUK9evVRYWKjS0tLgnhUAAAhrAQXKhg0bPvd4bGysysvLVV5efs016enp2r17dyAPCwAA/sfwu3gAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYJ6BAKSsr01133aU+ffooOTlZEydO1KlTp/zWXLx4UU6nU0lJSerdu7fy8/PV1NTkt6ahoUF5eXnq2bOnkpOTNW/ePF26dOnGzwYAAHQLAQVKdXW1nE6namtrVVlZqY6ODo0bN05tbW2+NXPnztXOnTu1bds2VVdX6+zZs5o0aZLveGdnp/Ly8tTe3q6DBw9q8+bN2rRpkxYtWhS8swIAAGEtOpDFe/bs8bu9adMmJScnq76+Xt/73vfU2tqqDRs2qKKiQmPGjJEkbdy4UQMHDlRtba1GjhypvXv36uTJk9q3b59sNpuGDRumpUuXav78+VqyZIliYmKCd3YAACAs3dA1KK2trZKkxMRESVJ9fb06OjqUk5PjWzNgwAClpaWppqZGklRTU6PBgwfLZrP51uTm5srtduvEiRNXfRyPxyO32+23AQCA7uu6A6Wrq0tz5szRqFGjNGjQIEmSy+VSTEyMEhIS/NbabDa5XC7fmv+Ok8vHLx+7mrKyMsXHx/u21NTU6x0bAACEgesOFKfTqffee09bt24N5jxXVVJSotbWVt/W2Nj4lT8mAAAInYCuQbmsqKhIu3bt0oEDB9SvXz/ffrvdrvb2drW0tPi9i9LU1CS73e5bc+jQIb+vd/lTPpfXfJbFYpHFYrmeUQEAQBgK6B0Ur9eroqIibd++Xfv371dGRobf8eHDh6tHjx6qqqry7Tt16pQaGhrkcDgkSQ6HQ8ePH1dzc7NvTWVlpaxWqzIzM2/kXAAAQDcR0DsoTqdTFRUVev3119WnTx/fNSPx8fGKi4tTfHy8pk2bpuLiYiUmJspqtWrWrFlyOBwaOXKkJGncuHHKzMzUlClTtHz5crlcLi1YsEBOp5N3SQAAgKQAA2Xt2rWSpHvuucdv/8aNG/WjH/1IkrRixQpFRkYqPz9fHo9Hubm5WrNmjW9tVFSUdu3apZkzZ8rhcKhXr14qLCxUaWnpjZ0JAADoNgIKFK/X+4VrYmNjVV5ervLy8muuSU9P1+7duwN5aAAA8D+E38UDAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4wQcKAcOHNCECROUkpKiiIgI7dixw++41+vVokWL1LdvX8XFxSknJ0enT5/2W3Pu3DkVFBTIarUqISFB06ZN04ULF27oRAAAQPcRcKC0tbVp6NChKi8vv+rx5cuXa9WqVVq3bp3q6urUq1cv5ebm6uLFi741BQUF
OnHihCorK7Vr1y4dOHBAM2bMuP6zAAAA3Up0oHcYP368xo8ff9VjXq9XK1eu1IIFC/Tggw9Kkl555RXZbDbt2LFDkydP1vvvv689e/bo8OHDysrKkiStXr1a999/v1544QWlpKTcwOkAAIDuIKjXoJw5c0Yul0s5OTm+ffHx8crOzlZNTY0kqaamRgkJCb44kaScnBxFRkaqrq4umOMAAIAwFfA7KJ/H5XJJkmw2m99+m83mO+ZyuZScnOw/RHS0EhMTfWs+y+PxyOPx+G673e5gjg0AAAwTFp/iKSsrU3x8vG9LTU0N9UgAAOArFNRAsdvtkqSmpia//U1NTb5jdrtdzc3NfscvXbqkc+fO+dZ8VklJiVpbW31bY2NjMMcGAACGCWqgZGRkyG63q6qqyrfP7Xarrq5ODodDkuRwONTS0qL6+nrfmv3796urq0vZ2dlX/boWi0VWq9VvAwAA3VfA16BcuHBBH374oe/2mTNndOzYMSUmJiotLU1z5szRM888o/79+ysjI0MLFy5USkqKJk6cKEkaOHCg7rvvPk2fPl3r1q1TR0eHioqKNHnyZD7BAwAAJF1HoBw5ckT33nuv73ZxcbEkqbCwUJs2bdJTTz2ltrY2zZgxQy0tLRo9erT27Nmj2NhY3322bNmioqIijR07VpGRkcrPz9eqVauCcDoAAKA7CDhQ7rnnHnm93msej4iIUGlpqUpLS6+5JjExURUVFYE+NAAA+B8RFp/iAQAA/1sIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABgnpIFSXl6uW2+9VbGxscrOztahQ4dCOQ4AADBEyALl1VdfVXFxsRYvXqyjR49q6NChys3NVXNzc6hGAgAAhghZoLz44ouaPn26pk6dqszMTK1bt049e/bUb3/721CNBAAADBEdigdtb29XfX29SkpKfPsiIyOVk5OjmpqaK9Z7PB55PB7f7dbWVkmS2+3+Subr8vzrK/m6gfiiczNhRok5g+nL/H0OhzlNmFFizmDi72ZwdZc5b+Rrer3eL17sDYGPP/7YK8l78OBBv/3z5s3zjhgx4or1ixcv9kpiY2NjY2Nj6wZbY2PjF7ZCSN5BCVRJSYmKi4t9t7u6unTu3DklJSUpIiIihJNdye12KzU1VY2NjbJaraEeJ+zxfAYPz2Vw8XwGD89lcJn8fHq9Xp0/f14pKSlfuDYkgXLzzTcrKipKTU1Nfvubmppkt9uvWG+xWGSxWPz2JSQkfJUj3jCr1WrcX4xwxvMZPDyXwcXzGTw8l8Fl6vMZHx//pdaF5CLZmJgYDR8+XFVVVb59XV1dqqqqksPhCMVIAADAICH7Fk9xcbEKCwuVlZWlESNGaOXKlWpra9PUqVNDNRIAADBEyALlkUce0d///nctWrRILpdLw4YN0549e2Sz2UI1UlBYLBYtXrz4im9J4frwfAYPz2Vw8XwGD89lcHWX5zPC6/0yn/UBAAD4+vC7eAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQgqy8vFy33nqrYmNjlZ2drUOHDoV6pLBTVlamu+66S3369FFycrImTpyoU6dOhXqsbuO5555TRESE5syZE+pRwtLHH3+sxx57TElJSYqLi9PgwYN15MiRUI8Vljo7O7Vw4UJlZGQoLi5O3/zmN7V06dIv93taoAMHDmjC
hAlKSUlRRESEduzY4Xfc6/Vq0aJF6tu3r+Li4pSTk6PTp0+HZtjrQKAE0auvvqri4mItXrxYR48e1dChQ5Wbm6vm5uZQjxZWqqur5XQ6VVtbq8rKSnV0dGjcuHFqa2sL9Whh7/Dhw3rppZc0ZMiQUI8Slj799FONGjVKPXr00JtvvqmTJ0/qV7/6lW666aZQjxaWnn/+ea1du1a/+c1v9P777+v555/X8uXLtXr16lCPFhba2to0dOhQlZeXX/X48uXLtWrVKq1bt051dXXq1auXcnNzdfHixa950usUjF/+h38bMWKE1+l0+m53dnZ6U1JSvGVlZSGcKvw1Nzd7JXmrq6tDPUpYO3/+vLd///7eyspK7/e//33v7NmzQz1S2Jk/f7539OjRoR6j28jLy/M+8cQTfvsmTZrkLSgoCNFE4UuSd/v27b7bXV1dXrvd7v3lL3/p29fS0uK1WCze3//+9yGYMHC8gxIk7e3tqq+vV05Ojm9fZGSkcnJyVFNTE8LJwl9ra6skKTExMcSThDen06m8vDy/v6MIzBtvvKGsrCw99NBDSk5O1h133KGXX3451GOFrbvvvltVVVX64IMPJEl//vOf9e6772r8+PEhniz8nTlzRi6Xy+/fe3x8vLKzs8PmNSksfptxOPjHP/6hzs7OK34Srs1m01/+8pcQTRX+urq6NGfOHI0aNUqDBg0K9Thha+vWrTp69KgOHz4c6lHC2kcffaS1a9equLhYP/vZz3T48GE9+eSTiomJUWFhYajHCztPP/203G63BgwYoKioKHV2durZZ59VQUFBqEcLey6XS5Ku+pp0+ZjpCBQYzel06r333tO7774b6lHCVmNjo2bPnq3KykrFxsaGepyw1tXVpaysLC1btkySdMcdd+i9997TunXrCJTr8Nprr2nLli2qqKjQ7bffrmPHjmnOnDlKSUnh+QQXyQbLzTffrKioKDU1Nfntb2pqkt1uD9FU4a2oqEi7du3S22+/rX79+oV6nLBVX1+v5uZm3XnnnYqOjlZ0dLSqq6u1atUqRUdHq7OzM9Qjho2+ffsqMzPTb9/AgQPV0NAQoonC27x58/T0009r8uTJGjx4sKZMmaK5c+eqrKws1KOFvcuvO+H8mkSgBElMTIyGDx+uqqoq376uri5VVVXJ4XCEcLLw4/V6VVRUpO3bt2v//v3KyMgI9UhhbezYsTp+/LiOHTvm27KyslRQUKBjx44pKioq1COGjVGjRl3xkfcPPvhA6enpIZoovP3rX/9SZKT/y1BUVJS6urpCNFH3kZGRIbvd7vea5Ha7VVdXFzavSXyLJ4iKi4tVWFiorKwsjRgxQitXrlRbW5umTp0a6tHCitPpVEVFhV5//XX16dPH9/3S+Ph4xcXFhXi68NOnT58rrt/p1auXkpKSuK4nQHPnztXdd9+tZcuW6eGHH9ahQ4e0fv16rV+/PtSjhaUJEybo2WefVVpamm6//Xb96U9/0osvvqgnnngi1KOFhQsXLujDDz/03T5z5oyOHTumxMREpaWlac6cOXrmmWfUv39/ZWRkaOHChUpJSdHEiRNDN3QgQv0xou5m9erV3rS0NG9MTIx3xIgR3tra2lCPFHYkXXXbuHFjqEfrNviY8fXbuXOnd9CgQV6LxeIdMGCAd/369aEeKWy53W7v7NmzvWlpad7Y2Fjvbbfd5v35z3/u9Xg8oR4tLLz99ttX/b+ysLDQ6/X++6PGCxcu9NpsNq/FYvGOHTvWe+rUqdAOHYAIr5cf2QcAAMzCNSgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADj/D+QRB0GMpYi7QAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "t = TSTensor(torch.arange(11).float())\n",
    "tt_ = []\n",
    "for _ in range(1000):\n",
    "    tt = TSShuffleSteps()(t, split_idx=0)\n",
    "    test_eq(len(set(tt.tolist())), len(t))\n",
    "    test_ne(tt, t)\n",
    "    tt_.extend([t for i,t in enumerate(tt) if t!=i])\n",
    "x, y = np.unique(tt_, return_counts=True) # This is to visualize distribution which should be equal for all and half for first and last items\n",
    "plt.bar(x, y);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSGaussianNoise(RandTransform):\n",
    "    \"\"\"Applies additive or multiplicative gaussian noise to a `TSTensor` batch.\n",
    "\n",
    "    Args:\n",
    "        magnitude: multiplier applied to standard-normal noise; a falsy (None, 0) or non-positive\n",
    "            value returns the input unchanged.\n",
    "        additive: if True the result is `o + noise`, otherwise `o * (1 + noise)`.\n",
    "        ex: optional index on the second-to-last axis whose values are copied back from the input.\n",
    "    \"\"\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=.5, additive=True, ex=None, **kwargs):\n",
    "        self.magnitude, self.additive, self.ex = magnitude, additive, ex\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        # Same guard as the other transforms in this module, so magnitude=None disables instead of raising\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        noise = self.magnitude * torch.randn_like(o)\n",
    "        output = o + noise if self.additive else o * (1 + noise)\n",
    "        if self.ex is not None: output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSGaussianNoise(.1, additive=True)(xb, split_idx=0).shape, xb.shape)\n",
    "test_eq(TSGaussianNoise(.1, additive=False)(xb, split_idx=0).shape, xb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSMagAddNoise(RandTransform):\n",
    "    \"\"\"Applies additive noise on the y-axis for each step of a `TSTensor` batch.\n",
    "\n",
    "    The noise is drawn from N(0, 1/3) and scaled by `magnitude` times the standard deviation of the\n",
    "    step-to-step differences computed along the last axis.\n",
    "\n",
    "    Args:\n",
    "        magnitude: noise multiplier; a falsy or non-positive value returns the input unchanged.\n",
    "        ex: optional index on the second-to-last axis whose values are copied back from the input.\n",
    "    \"\"\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=1, ex=None, **kwargs):\n",
    "        self.magnitude, self.ex = magnitude, ex\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        # Reduce over the last axis (-1) rather than a hard-coded axis 2, so inputs that are not 3D also work\n",
    "        step_std = (o[..., 1:] - o[..., :-1]).std(-1, keepdim=True)\n",
    "        noise = torch.normal(0, 1/3, o.shape, dtype=o.dtype, device=o.device)\n",
    "        output = o + noise * step_std * self.magnitude\n",
    "        if self.ex is not None: output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output\n",
    "\n",
    "class TSMagMulNoise(RandTransform):\n",
    "    \"\"\"Multiplies every value of a `TSTensor` batch by its own random factor.\n",
    "\n",
    "    Factors are drawn from a normal distribution with mean 1 and standard deviation `magnitude * .025`.\n",
    "    A falsy or non-positive `magnitude` returns the input unchanged; `ex` selects entries on the\n",
    "    second-to-last axis that are copied back from the input.\n",
    "    \"\"\"\n",
    "    order = 90\n",
    "\n",
    "    def __init__(self, magnitude=1, ex=None, **kwargs):\n",
    "        self.magnitude, self.ex = magnitude, ex\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def encodes(self, o: TSTensor):\n",
    "        disabled = not self.magnitude or self.magnitude <= 0\n",
    "        if disabled:\n",
    "            return o\n",
    "        factors = torch.normal(1, self.magnitude * .025, o.shape, dtype=o.dtype, device=o.device)\n",
    "        output = o * factors\n",
    "        if self.ex is None:\n",
    "            return output\n",
    "        output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSMagAddNoise()(xb, split_idx=0).shape, xb.shape)\n",
    "test_eq(TSMagMulNoise()(xb, split_idx=0).shape, xb.shape)\n",
    "test_ne(TSMagAddNoise()(xb, split_idx=0), xb)\n",
    "test_ne(TSMagMulNoise()(xb, split_idx=0), xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def random_curve_generator(o, magnitude=0.1, order=4, noise=None):\n",
    "    \"\"\"Returns a smooth random curve with one value per step of `o`'s last axis.\n",
    "\n",
    "    A cubic spline is fitted through `3 * (order - 1) + 1` knots spread over [-seq_len, 2 * seq_len - 1],\n",
    "    whose heights are drawn from N(1, magnitude), and is then evaluated at steps 0..seq_len-1.\n",
    "\n",
    "    Args:\n",
    "        o: any object exposing `.shape`; only the size of the last axis is used.\n",
    "        magnitude: standard deviation of the knot heights around 1.\n",
    "        order: controls the number of knots.\n",
    "        noise: unused; kept so existing callers keep working.\n",
    "    \"\"\"\n",
    "    seq_len = o.shape[-1]\n",
    "    n_knots = 3 * (order - 1) + 1\n",
    "    knots = np.linspace(-seq_len, 2 * seq_len - 1, n_knots, dtype=int)\n",
    "    # Integer knots collide when the sequence is short relative to the number of knots, and\n",
    "    # CubicSpline requires strictly increasing x: fall back to float knots only in that case.\n",
    "    if np.any(np.diff(knots) <= 0):\n",
    "        knots = np.linspace(-seq_len, 2 * seq_len - 1, n_knots)\n",
    "    f = CubicSpline(knots, np.random.normal(loc=1.0, scale=magnitude, size=n_knots), axis=-1)\n",
    "    return f(np.arange(seq_len))\n",
    "\n",
    "def random_cum_curve_generator(o, magnitude=0.1, order=4, noise=None):\n",
    "    \"\"\"Builds warped step positions in [0, seq_len - 1] from the cumulative sum of a random smooth curve.\n",
    "\n",
    "    The cumulative curve is shifted to start at 0, divided by its last value, clipped to [0, 1] and\n",
    "    finally stretched to the index range of `o`'s last axis.\n",
    "    \"\"\"\n",
    "    curve = random_curve_generator(o, magnitude=magnitude, order=order, noise=noise)\n",
    "    cum = curve.cumsum()\n",
    "    cum = cum - cum[0]\n",
    "    cum = cum / cum[-1]\n",
    "    last_idx = o.shape[-1] - 1\n",
    "    return np.clip(cum, 0, 1) * last_idx\n",
    "\n",
    "def random_cum_noise_generator(o, magnitude=0.1, noise=None):\n",
    "    \"\"\"Builds jittered step positions in [0, seq_len - 1] from noisy unit increments.\n",
    "\n",
    "    Each increment is 1 + N(0, magnitude), clipped to [0, 1000]; their cumulative sum is shifted to\n",
    "    start at 0, divided by its last value and stretched to the index range of `o`'s last axis.\n",
    "    `noise` is not used.\n",
    "    \"\"\"\n",
    "    seq_len = o.shape[-1]\n",
    "    increments = np.ones(seq_len) + np.random.normal(loc=0, scale=magnitude, size=seq_len)\n",
    "    cum = np.clip(increments, 0, 1000).cumsum()\n",
    "    cum = cum - cum[0]\n",
    "    cum = cum / cum[-1]\n",
    "    return cum * (seq_len - 1)\n",
    "\n",
    "def random_cum_linear_generator(o, magnitude=0.1):\n",
    "    \"\"\"Builds warped step positions in [0, seq_len - 1] where one random window runs at a different speed.\n",
    "\n",
    "    A window of random length (at most about `seq_len * magnitude`) gets an increment between .5 and 2\n",
    "    while every other step keeps an increment of 1. The cumulative sum is then rescaled to the index\n",
    "    range of `o`'s last axis.\n",
    "\n",
    "    Returns the unwarped positions `np.arange(seq_len)` when the window would cover the whole\n",
    "    sequence (or more), or when the sequence has fewer than 2 steps.\n",
    "    \"\"\"\n",
    "    seq_len = o.shape[-1]\n",
    "    win_len = int(round(seq_len * np.random.rand() * magnitude))\n",
    "    # >= rather than ==: magnitude > 1 can give a window longer than the sequence, which made randint raise.\n",
    "    # seq_len < 2 is excluded too because the normalisation below would divide 0 by 0.\n",
    "    if win_len >= seq_len or seq_len < 2: return np.arange(seq_len)\n",
    "    start = np.random.randint(0, seq_len - win_len)\n",
    "    # mult between .5 and 2\n",
    "    rand = np.random.rand()\n",
    "    mult = 1 + rand\n",
    "    if np.random.randint(2): mult = 1 - rand/2\n",
    "    x = np.ones(seq_len)\n",
    "    x[start : start + win_len] = mult\n",
    "    x = x.cumsum()\n",
    "    x -= x[0]\n",
    "    x /= x[-1]\n",
    "    return np.clip(x, 0, 1) * (seq_len - 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSTimeNoise(RandTransform):\n",
    "    \"\"\"Jitters the x-axis of a `TSTensor` batch.\n",
    "\n",
    "    A cubic spline is fitted through the batch along its last axis and re-sampled at the positions\n",
    "    returned by `random_cum_noise_generator`. A falsy or non-positive `magnitude` returns the input\n",
    "    unchanged; `ex` selects entries on the second-to-last axis that are copied back from the input.\n",
    "    \"\"\"\n",
    "    order = 90\n",
    "\n",
    "    def __init__(self, magnitude=0.1, ex=None, **kwargs):\n",
    "        self.magnitude, self.ex = magnitude, ex\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0:\n",
    "            return o\n",
    "        steps = np.arange(o.shape[-1])\n",
    "        spline = CubicSpline(steps, o.cpu(), axis=-1)\n",
    "        noisy_steps = random_cum_noise_generator(o, magnitude=self.magnitude)\n",
    "        output = o.new(spline(noisy_steps))\n",
    "        if self.ex is not None:\n",
    "            output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSTimeNoise()(xb, split_idx=0).shape, xb.shape)\n",
    "test_ne(TSTimeNoise()(xb, split_idx=0), xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSMagWarp(RandTransform):\n",
    "    \"\"\"Applies warping to the y-axis of a `TSTensor` batch based on a smooth random curve.\n",
    "\n",
    "    Every step is multiplied by the value of a curve produced by `random_curve_generator`.\n",
    "\n",
    "    Args:\n",
    "        magnitude: forwarded to `random_curve_generator`; a falsy (None, 0) or non-positive value\n",
    "            returns the input unchanged.\n",
    "        ord: forwarded to `random_curve_generator` as `order`.\n",
    "        ex: optional index on the second-to-last axis whose values are copied back from the input.\n",
    "    \"\"\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.02, ord=4, ex=None, **kwargs):\n",
    "        self.magnitude, self.ord, self.ex = magnitude, ord, ex\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        # `not ... or` (was `... and`): magnitude None/0 must short-circuit here instead of reaching the generator\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        y_mult = random_curve_generator(o, magnitude=self.magnitude, order=self.ord)\n",
    "        output = o * o.new(y_mult)\n",
    "        if self.ex is not None: output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSMagWarp()(xb, split_idx=0).shape, xb.shape)\n",
    "test_ne(TSMagWarp()(xb, split_idx=0), xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSTimeWarp(RandTransform):\n",
    "    \"\"\"Warps the x-axis of a `TSTensor` batch along a smooth random curve.\n",
    "\n",
    "    A cubic spline is fitted through the batch along its last axis and re-sampled at the positions\n",
    "    returned by `random_cum_curve_generator` (called with `order=ord`). A falsy or non-positive\n",
    "    `magnitude` returns the input unchanged; `ex` selects entries on the second-to-last axis that are\n",
    "    copied back from the input.\n",
    "    \"\"\"\n",
    "    order = 90\n",
    "\n",
    "    def __init__(self, magnitude=0.1, ord=6, ex=None, **kwargs):\n",
    "        self.magnitude, self.ord, self.ex = magnitude, ord, ex\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0:\n",
    "            return o\n",
    "        steps = np.arange(o.shape[-1])\n",
    "        spline = CubicSpline(steps, o.cpu(), axis=-1)\n",
    "        warped_steps = random_cum_curve_generator(o, magnitude=self.magnitude, order=self.ord)\n",
    "        output = o.new(spline(warped_steps))\n",
    "        if self.ex is not None:\n",
    "            output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSTimeWarp()(xb, split_idx=0).shape, xb.shape)\n",
    "test_ne(TSTimeWarp()(xb, split_idx=0), xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSWindowWarp(RandTransform):\n",
    "    \"\"\"Re-samples the x-axis of a `TSTensor` batch along a random piecewise-linear curve, based on\n",
    "    https://halshs.archives-ouvertes.fr/halshs-01357973/document\n",
    "\n",
    "    The transform is skipped (input returned unchanged) unless 0 < magnitude < 1. `ex` selects entries\n",
    "    on the second-to-last axis that are copied back from the input.\n",
    "    \"\"\"\n",
    "    order = 90\n",
    "\n",
    "    def __init__(self, magnitude=0.1, ex=None, **kwargs):\n",
    "        self.magnitude, self.ex = magnitude, ex\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude:\n",
    "            return o\n",
    "        if self.magnitude <= 0 or self.magnitude >= 1:\n",
    "            return o\n",
    "        steps = np.arange(o.shape[-1])\n",
    "        spline = CubicSpline(steps, o.cpu(), axis=-1)\n",
    "        warped_steps = random_cum_linear_generator(o, magnitude=self.magnitude)\n",
    "        output = o.new(spline(warped_steps))\n",
    "        if self.ex is not None:\n",
    "            output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSWindowWarp()(xb, split_idx=0).shape, xb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSMagScale(RandTransform):\n",
    "    \"\"\"Multiplies a whole `TSTensor` batch by a single random scalar.\n",
    "\n",
    "    With `rand = random_half_normal()`, the factor is `1 - rand * magnitude / 2` when\n",
    "    `random.random() > 1/3` and `1 + rand * magnitude` otherwise. A falsy or non-positive `magnitude`\n",
    "    returns the input unchanged; `ex` selects entries on the second-to-last axis that are copied back\n",
    "    from the input.\n",
    "    \"\"\"\n",
    "    order = 90\n",
    "\n",
    "    def __init__(self, magnitude=0.5, ex=None, **kwargs):\n",
    "        self.magnitude, self.ex = magnitude, ex\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0:\n",
    "            return o\n",
    "        half_normal = random_half_normal()\n",
    "        if random.random() > 1/3:\n",
    "            factor = 1 - (half_normal  * self.magnitude)/2\n",
    "        else:\n",
    "            factor = 1 + (half_normal  * self.magnitude)\n",
    "        output = o * factor\n",
    "        if self.ex is not None:\n",
    "            output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output\n",
    "\n",
    "class TSMagScalePerVar(RandTransform):\n",
    "    \"\"\"Multiplies a `TSTensor` batch by random factors, one per entry of the second-to-last axis.\n",
    "\n",
    "    `random_half_normal_tensor` is asked for a tensor whose shape is 1 everywhere except on the\n",
    "    second-to-last axis, so the factors broadcast over all other axes. A single `random.random()`\n",
    "    draw per call selects between `1 - rand * magnitude / 2` and `1 + rand * magnitude` for all\n",
    "    factors at once. A falsy or non-positive `magnitude` returns the input unchanged; `ex` selects\n",
    "    entries on the second-to-last axis that are copied back from the input.\n",
    "    \"\"\"\n",
    "    order = 90\n",
    "\n",
    "    def __init__(self, magnitude=0.5, ex=None, **kwargs):\n",
    "        self.magnitude, self.ex = magnitude, ex\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0:\n",
    "            return o\n",
    "        factor_shape = [1] * o.ndim\n",
    "        factor_shape[-2] = o.shape[-2]\n",
    "        half_normal = random_half_normal_tensor(factor_shape, device=o.device)\n",
    "        if random.random() > 1/3:\n",
    "            factors = 1 - (half_normal  * self.magnitude)/2\n",
    "        else:\n",
    "            factors = 1 + (half_normal  * self.magnitude)\n",
    "        output = o * factors\n",
    "        if self.ex is not None:\n",
    "            output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output\n",
    "\n",
    "# Alternative name bound to the same class\n",
    "TSMagScaleByVar = TSMagScalePerVar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSMagScale()(xb, split_idx=0).shape, xb.shape)\n",
    "test_eq(TSMagScalePerVar()(xb, split_idx=0).shape, xb.shape)\n",
    "test_ne(TSMagScale()(xb, split_idx=0), xb)\n",
    "test_ne(TSMagScalePerVar()(xb, split_idx=0), xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def test_interpolate(mode=\"linear\"):\n",
    "    \"\"\"Checks whether `F.interpolate` supports `mode` on the default device.\n",
    "\n",
    "    Runs a tiny (1, 1, 8) interpolation on `default_device()`. Returns True when it succeeds; when it\n",
    "    raises `NotImplementedError`, prints an explanation and returns False.\n",
    "    \"\"\"\n",
    "    assert mode in [\"nearest\", \"linear\", \"area\"], \"Mode must be 'nearest', 'linear' or 'area'.\"\n",
    "\n",
    "    # Small random probe placed on the device under test\n",
    "    probe = torch.randn(1, 1, 8, device=default_device())\n",
    "\n",
    "    try:\n",
    "        # Only success or failure of the call matters, the interpolated values are discarded\n",
    "        F.interpolate(probe, scale_factor=2, mode=mode)\n",
    "    except NotImplementedError as e:\n",
    "        print(f\"{mode} interpolation is not supported by {default_device()}. You can try a different mode\")\n",
    "        print(\"Error:\", e)\n",
    "        return False\n",
    "    return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "linear interpolation is not supported by mps. You can try a different mode\n",
      "Error: The operator 'aten::upsample_linear1d.out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run the test\n",
    "test_interpolate('linear')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_interpolate('nearest')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSRandomResizedCrop(RandTransform):\n",
    "    \"Randomly amplifies a sequence focusing on a random section of the steps\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.1, size=None, scale=None, ex=None, mode='nearest', **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            magnitude: used as both parameters of the Beta distribution that draws the crop fraction when\n",
    "                `scale` is None. A falsy or non-positive value disables the transform.\n",
    "            size: None (keep the input length), int (output length) or float (fraction of the input length)\n",
    "            scale: None or tuple of 2 floats 0 < float <= 1; the crop fraction is drawn uniformly between them\n",
    "            ex: optional index on the second-to-last axis whose values are copied back from the input.\n",
    "                Only applied when the output has the same length as the input.\n",
    "            mode:  'nearest' | 'linear' | 'area'\n",
    "\n",
    "        \"\"\"\n",
    "\n",
    "        if not test_interpolate(mode):\n",
    "            # Report the real class name (the message used to print the literal text 'self.__name__')\n",
    "            print(f\"{type(self).__name__} will not be applied because {mode} interpolation is not supported by {default_device()}. You can try a different mode\")\n",
    "            magnitude = 0\n",
    "\n",
    "        self.magnitude, self.ex, self.mode = magnitude, ex, mode\n",
    "        if scale is not None:\n",
    "            # Both bounds must lie in (0, 1]: check the smallest against 0 and the largest against 1\n",
    "            assert is_listy(scale) and len(scale) == 2 and min(scale) > 0 and max(scale) <= 1, \"scale must be a tuple with 2 floats 0 < float <= 1\"\n",
    "        self.size,self.scale = size,scale\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        seq_len = o.shape[-1]\n",
    "        if self.size is not None:\n",
    "            size = self.size if isinstance(self.size, Integral) else int(round(self.size * seq_len))\n",
    "        else:\n",
    "            size = seq_len\n",
    "        if self.scale is not None:\n",
    "            lambd = np.random.uniform(self.scale[0], self.scale[1])\n",
    "        else:\n",
    "            lambd = np.random.beta(self.magnitude, self.magnitude)\n",
    "            lambd = max(lambd, 1 - lambd)\n",
    "        # At least one step: an empty crop cannot be interpolated\n",
    "        win_len = max(1, int(round(seq_len * lambd)))\n",
    "        if win_len >= seq_len:\n",
    "            if size == seq_len: return o\n",
    "            _slice = slice(None)\n",
    "        else:\n",
    "            start = np.random.randint(0, seq_len - win_len)\n",
    "            _slice = slice(start, start + win_len)\n",
    "        output = F.interpolate(o[..., _slice], size=size, mode=self.mode, align_corners=None if self.mode in ['nearest', 'area'] else False)\n",
    "        # Honour `ex` like the other transforms in this module; impossible when the length changed\n",
    "        if self.ex is not None and output.shape[-1] == seq_len: output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output\n",
    "\n",
    "TSRandomZoomIn = TSRandomResizedCrop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if test_interpolate('nearest'):\n",
    "    test_eq(TSRandomResizedCrop(.5)(xb, split_idx=0).shape, xb.shape)\n",
    "    test_ne(TSRandomResizedCrop(size=.8, scale=(.5, 1))(xb, split_idx=0).shape, xb.shape)\n",
    "    test_ne(TSRandomResizedCrop(size=20, scale=(.5, 1))(xb, split_idx=0).shape, xb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSWindowSlicing(RandTransform):\n",
    "    \"Randomly extracts and resizes a ts slice based on https://halshs.archives-ouvertes.fr/halshs-01357973/document\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.1, ex=None, mode='nearest', **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            magnitude: fraction of the sequence left out of the slice; values outside (0, 1) make the transform a no-op\n",
    "            ex: stored on the instance; not used by `encodes`\n",
    "            mode:  'nearest' | 'linear' | 'area'\n",
    "        \"\"\"\n",
    "\n",
    "        if not test_interpolate(mode):\n",
    "            # Report the actual class name (the text `self.__name__` used to be printed literally)\n",
    "            print(f\"{type(self).__name__} will not be applied because {mode} interpolation is not supported by {default_device()}. You can try a different mode\")\n",
    "            magnitude = 0\n",
    "\n",
    "        self.magnitude, self.ex, self.mode = magnitude, ex, mode\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0 or self.magnitude >= 1: return o\n",
    "        seq_len = o.shape[-1]\n",
    "        # Keep at least one step so there is always something to interpolate\n",
    "        win_len = max(1, int(round(seq_len * (1 - self.magnitude))))\n",
    "        if win_len == seq_len: return o\n",
    "        # randint excludes its upper bound: + 1 allows the slice to end on the last step\n",
    "        start = np.random.randint(0, seq_len - win_len + 1)\n",
    "        # Stretch the slice back to the original length\n",
    "        return F.interpolate(o[..., start : start + win_len], size=seq_len, mode=self.mode, align_corners=None if self.mode in ['nearest', 'area'] else False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if test_interpolate('nearest'):\n",
    "    test_eq(TSWindowSlicing()(xb, split_idx=0).shape, xb.shape)\n",
    "    test_ne(TSWindowSlicing()(xb, split_idx=0), xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSRandomZoomOut(RandTransform):\n",
    "    \"Randomly compresses a sequence on the x-axis\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.1, ex=None, mode='nearest', **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            magnitude: both shape parameters of the Beta distribution the compression ratio is drawn from;\n",
    "                       a falsy or non-positive value makes the transform return its input unchanged\n",
    "            ex: stored on the instance; not used by `encodes`\n",
    "            mode:  'nearest' | 'linear' | 'area'\n",
    "        \"\"\"\n",
    "\n",
    "        if not test_interpolate(mode):\n",
    "            # Report the actual class name (the text `self.__name__` used to be printed literally)\n",
    "            print(f\"{type(self).__name__} will not be applied because {mode} interpolation is not supported by {default_device()}. You can try a different mode\")\n",
    "            magnitude = 0\n",
    "\n",
    "        self.magnitude, self.ex, self.mode = magnitude, ex, mode\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        seq_len = o.shape[-1]\n",
    "        lambd = np.random.beta(self.magnitude, self.magnitude)\n",
    "        # Keep the larger side of the draw, so the compressed length is at least half the sequence\n",
    "        lambd = max(lambd, 1 - lambd)\n",
    "        win_len = int(round(seq_len * lambd))\n",
    "        if win_len == seq_len: return o\n",
    "        # Centre the compressed sequence; the steps on both sides stay at zero\n",
    "        start = (seq_len - win_len) // 2\n",
    "        output = torch.zeros_like(o)  # zeros_like already inherits dtype and device from `o`\n",
    "        interp = F.interpolate(o, size=win_len, mode=self.mode, align_corners=None if self.mode in ['nearest', 'area'] else False)\n",
    "        output[..., start:start + win_len] = o.new(interp)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if test_interpolate('nearest'):\n",
    "    test_eq(TSRandomZoomOut(.5)(xb, split_idx=0).shape, xb.shape)#"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSRandomTimeScale(RandTransform):\n",
    "    \"Randomly amplifies/ compresses a sequence on the x-axis keeping the same length\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.1, ex=None, mode='nearest', **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            magnitude: forwarded to the zoom transform applied in `encodes`;\n",
    "                       a falsy or non-positive value makes the transform return its input unchanged\n",
    "            ex: forwarded to the zoom transform applied in `encodes`\n",
    "            mode:  'nearest' | 'linear' | 'area'\n",
    "        \"\"\"\n",
    "\n",
    "        if not test_interpolate(mode):\n",
    "            # Report the actual class name (the text `self.__name__` used to be printed literally)\n",
    "            print(f\"{type(self).__name__} will not be applied because {mode} interpolation is not supported by {default_device()}. You can try a different mode\")\n",
    "            magnitude = 0\n",
    "\n",
    "        self.magnitude, self.ex, self.mode = magnitude, ex, mode\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        # Zoom in or out depending on a single uniform draw, reusing this transform's settings\n",
    "        if np.random.rand() <= 0.5: return TSRandomZoomIn(magnitude=self.magnitude, ex=self.ex, mode=self.mode)(o, split_idx=0)\n",
    "        else: return TSRandomZoomOut(magnitude=self.magnitude, ex=self.ex, mode=self.mode)(o, split_idx=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if test_interpolate('nearest'):\n",
    "    test_eq(TSRandomTimeScale(.5)(xb, split_idx=0).shape, xb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSRandomTimeStep(RandTransform):\n",
    "    \"Compresses a sequence on the x-axis by randomly selecting sequence steps and interpolating to previous size\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.02, ex=None, mode='nearest', **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            magnitude: upper bound of the fraction of steps that may be dropped (at most half of the steps are dropped);\n",
    "                       a falsy or non-positive value makes the transform return its input unchanged\n",
    "            ex: index(es) on dim -2 whose original values are restored in the output\n",
    "            mode:  'nearest' | 'linear' | 'area'\n",
    "        \"\"\"\n",
    "\n",
    "        if not test_interpolate(mode):\n",
    "            # Report the actual class name (the text `self.__name__` used to be printed literally)\n",
    "            print(f\"{type(self).__name__} will not be applied because {mode} interpolation is not supported by {default_device()}. You can try a different mode\")\n",
    "            magnitude = 0\n",
    "\n",
    "        self.magnitude, self.ex, self.mode = magnitude, ex, mode\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        seq_len = o.shape[-1]\n",
    "        # Number of steps kept: never fewer than half of the sequence\n",
    "        new_seq_len = int(round(seq_len * max(.5, (1 - np.random.rand() * self.magnitude))))\n",
    "        if  new_seq_len == seq_len: return o\n",
    "        # Pick the surviving steps without replacement, in chronological order\n",
    "        timesteps = np.sort(random_choice(np.arange(seq_len),new_seq_len, replace=False))\n",
    "        output = F.interpolate(o[..., timesteps], size=seq_len, mode=self.mode, align_corners=None if self.mode in ['nearest', 'area'] else False)\n",
    "        if self.ex is not None: output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if test_interpolate('nearest'):\n",
    "    test_eq(TSRandomTimeStep()(xb, split_idx=0).shape, xb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "\n",
    "class TSResampleSteps(RandTransform):\n",
    "    \"Transform that randomly selects and sorts sequence steps (with replacement) maintaining the sequence length\"\n",
    "\n",
    "    order = 90\n",
    "    def __init__(self, step_pct=1., same_seq_len=True, magnitude=None, **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            step_pct: fraction of the steps to draw; a float, or a (low, high) tuple to draw the fraction uniformly per call\n",
    "            same_seq_len: when True (and step_pct != 1) the drawn steps are repeated so the output keeps the input length\n",
    "            magnitude: accepted but not used\n",
    "        \"\"\"\n",
    "        # `encodes` supports the tuple form, and comparing a tuple with 0 raises TypeError: validate its smallest value\n",
    "        assert (min(step_pct) if isinstance(step_pct, tuple) else step_pct) > 0, 'step_pct must be > 0'\n",
    "        self.step_pct, self.same_seq_len = step_pct, same_seq_len\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def encodes(self, o: TSTensor):\n",
    "        S = o.shape[-1]\n",
    "        if isinstance(self.step_pct, tuple):\n",
    "            step_pct = np.random.rand() * (self.step_pct[1] - self.step_pct[0]) + self.step_pct[0]\n",
    "        else:\n",
    "            step_pct = self.step_pct\n",
    "        # Always draw at least one step so the output is never empty\n",
    "        n_steps = max(1, round(S * step_pct))\n",
    "        if step_pct != 1 and self.same_seq_len:\n",
    "            # ceil(S / n_steps) repetitions guarantee at least S indices before truncating\n",
    "            # (ceil(1 / step_pct) could fall short after rounding, e.g. S=10 and step_pct=.34 gave 9)\n",
    "            idxs = np.sort(np.tile(random_choice(S, n_steps, True), math.ceil(S / n_steps))[:S])\n",
    "        else:\n",
    "            idxs = np.sort(random_choice(S, n_steps, True))\n",
    "        return o[..., idxs]\n",
    "\n",
    "TSSubsampleSteps = TSResampleSteps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSResampleSteps(step_pct=.9, same_seq_len=False)(xb, split_idx=0).shape[-1], round(.9*xb.shape[-1]))\n",
    "test_eq(TSResampleSteps(step_pct=.9, same_seq_len=True)(xb, split_idx=0).shape[-1], xb.shape[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSBlur(RandTransform):\n",
    "    \"Blurs a sequence applying a filter of type [1, 0, 1]\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=1., ex=None, filt_len=None, **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            magnitude: on/off switch; a falsy or non-positive value makes the transform return its input unchanged\n",
    "            ex: index(es) on dim -2 whose original values are restored in the output\n",
    "            filt_len: None for the [1, 0, 1] filter; otherwise each side of the central 0 gets max(1, filt_len // 2) ones\n",
    "        \"\"\"\n",
    "        self.magnitude = magnitude\n",
    "        self.ex = ex\n",
    "        side = 1 if filt_len is None else max(1, filt_len // 2)\n",
    "        kernel = np.array([1] * side + [0] + [1] * side)\n",
    "        # Normalise so the filter weights add up to 1\n",
    "        self.filterargs = kernel/kernel.sum()\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        # scipy runs on the CPU copy; `o.new` wraps the filtered array again\n",
    "        blurred = convolve1d(o.cpu(), self.filterargs, mode='nearest')\n",
    "        output = o.new(blurred)\n",
    "        if self.ex is not None:\n",
    "            output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSBlur(filt_len=7)(xb, split_idx=0).shape, xb.shape)\n",
    "test_ne(TSBlur()(xb, split_idx=0), xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSSmooth(RandTransform):\n",
    "    \"Smoothens a sequence applying a filter of type [1, 5, 1]\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=1., ex=None, filt_len=None, **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            magnitude: on/off switch; a falsy or non-positive value makes the transform return its input unchanged\n",
    "            ex: index(es) on dim -2 whose original values are restored in the output\n",
    "            filt_len: None for the [1, 5, 1] filter; otherwise each side of the central 5 gets max(1, filt_len // 2) ones\n",
    "        \"\"\"\n",
    "        self.magnitude, self.ex = magnitude, ex\n",
    "        # (a first assignment of np.array([1, 5, 1]) was removed here: it was always overwritten below)\n",
    "        if filt_len is None:\n",
    "            filterargs = [1, 5, 1]\n",
    "        else:\n",
    "            filterargs = ([1] * max(1, filt_len // 2) + [5] + [1] * max(1, filt_len // 2))\n",
    "        self.filterargs = np.array(filterargs)\n",
    "        # Normalise so the filter weights add up to 1\n",
    "        self.filterargs = self.filterargs/self.filterargs.sum()\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        # scipy runs on the CPU copy; `o.new` wraps the filtered array again\n",
    "        output = o.new(convolve1d(o.cpu(), self.filterargs, mode='nearest'))\n",
    "        if self.ex is not None: output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSSmooth(filt_len=7)(xb, split_idx=0).shape, xb.shape)\n",
    "test_ne(TSSmooth()(xb, split_idx=0), xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def maddest(d, axis=None):\n",
    "    \"Mean Absolute Deviation of `d` around its mean along `axis` (over all elements when `axis` is None)\"\n",
    "    # keepdims=True keeps the mean broadcastable against `d` when an axis is given\n",
    "    return np.mean(np.absolute(d - np.mean(d, axis=axis, keepdims=True)), axis=axis)\n",
    "\n",
    "class TSFreqDenoise(RandTransform):\n",
    "    \"Denoises a sequence applying a wavelet decomposition method\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.1, ex=None, wavelet='db4', level=2, thr=None, thr_mode='hard', pad_mode='per', **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            magnitude: a falsy or non-positive value makes the transform return its input unchanged;\n",
    "                       it also scales the threshold when `thr` is not None\n",
    "            ex: index(es) on dim -2 whose original values are restored in the output\n",
    "            wavelet: wavelet passed to pywt.wavedec / pywt.waverec\n",
    "            level: the coefficients at position -level are used to estimate sigma\n",
    "            thr: only checked against None (its value is not used)\n",
    "            thr_mode: mode passed to pywt.threshold\n",
    "            pad_mode: mode passed to pywt.wavedec / pywt.waverec\n",
    "        \"\"\"\n",
    "        self.magnitude, self.ex = magnitude, ex\n",
    "        self.wavelet, self.level, self.thr, self.thr_mode, self.pad_mode = wavelet, level, thr, thr_mode, pad_mode\n",
    "        super().__init__(**kwargs)\n",
    "        # Fail at construction time when the optional dependency is missing\n",
    "        try:\n",
    "            import pywt\n",
    "        except ImportError:\n",
    "            raise ImportError('You need to install pywt to run TSFreqDenoise')\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        # The import in __init__ only binds a name local to __init__, so import again here\n",
    "        # (cheap: the module is already cached in sys.modules)\n",
    "        import pywt\n",
    "        # 1. Adapted from waveletSmooth function found here:\n",
    "        #    http://connor-johnson.com/2016/01/24/using-pywavelets-to-remove-high-frequency-noise/\n",
    "        # 2. Threshold equation and using hard mode in threshold as mentioned\n",
    "        #    in section '3.2 denoising based on optimized singular values' from paper by Tomas Vantuch:\n",
    "        #    http://dspace.vsb.cz/bitstream/handle/10084/133114/VAN431_FEI_P1807_1801V001_2018.pdf\n",
    "        seq_len = o.shape[-1]\n",
    "        # Decompose to get the wavelet coefficients\n",
    "        coeff = pywt.wavedec(o.cpu(), self.wavelet, mode=self.pad_mode)\n",
    "        # Calculate sigma for threshold as defined in http://dspace.vsb.cz/bitstream/handle/10084/133114/VAN431_FEI_P1807_1801V001_2018.pdf\n",
    "        # As noted by @harshit92 MAD referred to in the paper is Mean Absolute Deviation not Median Absolute Deviation\n",
    "        sigma = (1/0.6745) * maddest(coeff[-self.level])\n",
    "        # Calculate the universal threshold\n",
    "        uthr = sigma * np.sqrt(2*np.log(seq_len)) * (1 if self.thr is None else self.magnitude)\n",
    "        # Threshold every detail level; the approximation coefficients (coeff[0]) are left untouched\n",
    "        coeff[1:] = [pywt.threshold(c, value=uthr, mode=self.thr_mode) for c in coeff[1:]]\n",
    "        # Reconstruct the signal using the thresholded coefficients\n",
    "        output = o.new(pywt.waverec(coeff, self.wavelet, mode=self.pad_mode)[..., :seq_len])\n",
    "        if self.ex is not None: output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try: import pywt\n",
    "except ImportError: pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if 'pywt' in dir():\n",
    "    test_eq(TSFreqDenoise()(xb, split_idx=0).shape, xb.shape)\n",
    "    test_ne(TSFreqDenoise()(xb, split_idx=0), xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSRandomFreqNoise(RandTransform):\n",
    "    \"Applies random noise using a wavelet decomposition method\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.1, ex=None, wavelet='db4', level=2, mode='constant', **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            magnitude: maximum relative change applied to each detail level;\n",
    "                       a falsy or non-positive value makes the transform return its input unchanged\n",
    "            ex: index(es) on dim -2 whose original values are restored in the output\n",
    "            wavelet: wavelet passed to pywt.wavedec / pywt.waverec\n",
    "            level: decomposition level passed to pywt.wavedec (None is treated as 1)\n",
    "            mode: mode passed to pywt.wavedec / pywt.waverec\n",
    "        \"\"\"\n",
    "        self.magnitude, self.ex = magnitude, ex\n",
    "        self.wavelet, self.level, self.mode = wavelet, level, mode\n",
    "        super().__init__(**kwargs)\n",
    "        # Fail at construction time when the optional dependency is missing\n",
    "        try:\n",
    "            import pywt\n",
    "        except ImportError:\n",
    "            raise ImportError('You need to install pywt to run TSRandomFreqNoise')\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        # The import in __init__ only binds a name local to __init__, so import again here\n",
    "        # (cheap: the module is already cached in sys.modules)\n",
    "        import pywt\n",
    "        # Resolve the default into a local instead of rewriting the configured attribute on every call\n",
    "        level = 1 if self.level is None else self.level\n",
    "        coeff = pywt.wavedec(o.cpu(), self.wavelet, mode=self.mode, level=level)\n",
    "        # Rescale each detail level by an independent factor drawn from [1 - magnitude, 1 + magnitude)\n",
    "        coeff[1:] = [c * (1 + 2 * (np.random.rand() - 0.5) * self.magnitude) for c in coeff[1:]]\n",
    "        output = o.new(pywt.waverec(coeff, self.wavelet, mode=self.mode)[..., :o.shape[-1]])\n",
    "        if self.ex is not None: output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if 'pywt' in dir():\n",
    "    test_eq(TSRandomFreqNoise()(xb, split_idx=0).shape, xb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSRandomResizedLookBack(RandTransform):\n",
    "    \"Selects a random number of sequence steps starting from the end and return an output of the same shape\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.1, mode='nearest', **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            magnitude: both shape parameters of the Beta distribution the dropped fraction is drawn from;\n",
    "                       a falsy or non-positive value makes the transform return its input unchanged\n",
    "            mode:  'nearest' | 'linear' | 'area'\n",
    "        \"\"\"\n",
    "\n",
    "        if not test_interpolate(mode):\n",
    "            # Report the actual class name (the text `self.__name__` used to be printed literally)\n",
    "            print(f\"{type(self).__name__} will not be applied because {mode} interpolation is not supported by {default_device()}. You can try a different mode\")\n",
    "            magnitude = 0\n",
    "\n",
    "        self.magnitude, self.mode = magnitude, mode\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        seq_len = o.shape[-1]\n",
    "        lambd = np.random.beta(self.magnitude, self.magnitude)\n",
    "        # Keep the smaller side of the draw, so at most half of the sequence is dropped\n",
    "        lambd = min(lambd, 1 - lambd)\n",
    "        # Drop the leading steps and stretch the remaining tail back to the original length\n",
    "        output = o.clone()[..., int(round(lambd * seq_len)):]\n",
    "        return F.interpolate(output, size=seq_len, mode=self.mode, align_corners=None if self.mode in ['nearest', 'area'] else False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if test_interpolate('nearest'):\n",
    "    for i in range(100):\n",
    "        o = TSRandomResizedLookBack()(xb, split_idx=0)\n",
    "        test_eq(o.shape[-1], xb.shape[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSRandomLookBackOut(RandTransform):\n",
    "    \"Sets a random number of leading sequence steps to zero, keeping the steps at the end untouched\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.1, **kwargs):\n",
    "        \"magnitude: both shape parameters of the Beta distribution the zeroed fraction is drawn from\"\n",
    "        self.magnitude = magnitude\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        n_steps = o.shape[-1]\n",
    "        frac = np.random.beta(self.magnitude, self.magnitude)\n",
    "        # The smaller side of the draw is used, so at most half of the steps are zeroed\n",
    "        n_zeroed = int(round(min(frac, 1 - frac) * n_steps))\n",
    "        masked = o.clone()\n",
    "        masked[..., :n_zeroed] = 0\n",
    "        return masked"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(100):\n",
    "    o = TSRandomLookBackOut()(xb, split_idx=0)\n",
    "    test_eq(o.shape[-1], xb.shape[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSVarOut(RandTransform):\n",
    "    \"Set the value of a random number of variables to zero\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.05, ex=None, **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            magnitude: both shape parameters of the Beta distribution the fraction of zeroed variables is drawn from\n",
    "            ex: index(es) on dim -2 whose original values are restored in the output\n",
    "        \"\"\"\n",
    "        self.magnitude = magnitude\n",
    "        self.ex = ex\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        n_vars = o.shape[-2]\n",
    "        # A single variable is never zeroed\n",
    "        if n_vars == 1: return o\n",
    "        frac = np.random.beta(self.magnitude, self.magnitude)\n",
    "        frac = min(frac, 1 - frac)\n",
    "        # Selection weights: normalised cumulative sums of 0..n_vars-1, reversed so the first variables weigh the most\n",
    "        weights = np.arange(n_vars).cumsum()\n",
    "        weights = weights/weights[-1]\n",
    "        weights = (weights / weights.sum())[::-1]\n",
    "        n_out = int(round(frac * n_vars))\n",
    "        dropped = random_choice(np.arange(n_vars), n_out, p=weights, replace=False)\n",
    "        if len(dropped) == 0:  return o\n",
    "        result = o.clone()\n",
    "        result[...,dropped,:] = 0\n",
    "        if self.ex is not None:\n",
    "            result[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSVarOut()(xb, split_idx=0).shape, xb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSCutOut(RandTransform):\n",
    "    \"Sets a random section of the sequence to zero\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.05, ex=None, **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            magnitude: both shape parameters of the Beta distribution the length of the zeroed section is drawn from\n",
    "            ex: index(es) on dim -2 whose original values are restored in the output\n",
    "        \"\"\"\n",
    "        self.magnitude = magnitude\n",
    "        self.ex = ex\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        n_steps = o.shape[-1]\n",
    "        frac = np.random.beta(self.magnitude, self.magnitude)\n",
    "        cut_len = int(round(n_steps * min(frac, 1 - frac)))\n",
    "        # The section may start before the sequence or run past its end; it is clipped to the valid range\n",
    "        first = np.random.randint(-cut_len + 1, n_steps)\n",
    "        lo, hi = max(0, first), min(first + cut_len, n_steps)\n",
    "        result = o.clone()\n",
    "        result[..., lo:hi] = 0\n",
    "        if self.ex is not None:\n",
    "            result[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSCutOut()(xb, split_idx=0).shape, xb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSTimeStepOut(RandTransform):\n",
    "    \"Sets random sequence steps to zero\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.05, ex=None, **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            magnitude: fraction of the steps set to zero (capped at .5)\n",
    "            ex: index(es) on dim -2 whose original values are restored in the output\n",
    "        \"\"\"\n",
    "        self.magnitude = magnitude\n",
    "        self.ex = ex\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        n_steps = o.shape[-1]\n",
    "        # Never zero more than half of the steps\n",
    "        n_out = int(round(n_steps * min(.5, self.magnitude)))\n",
    "        picked = random_choice(np.arange(n_steps), n_out, replace=False)\n",
    "        picked = np.sort(picked)\n",
    "        result = o.clone()\n",
    "        result[..., picked] = 0\n",
    "        if self.ex is not None:\n",
    "            result[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSTimeStepOut()(xb, split_idx=0).shape, xb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSRandomCropPad(RandTransform):\n",
    "    \"Crops a section of the sequence of a random length\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.05, ex=None, **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            magnitude: both shape parameters of the Beta distribution the kept fraction is drawn from;\n",
    "                       a falsy or non-positive value makes the transform return its input unchanged\n",
    "            ex: index(es) on dim -2 whose original values are restored in the output\n",
    "        \"\"\"\n",
    "        self.magnitude, self.ex = magnitude, ex\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        seq_len = o.shape[-1]\n",
    "        lambd = np.random.beta(self.magnitude, self.magnitude)\n",
    "        # Keep the larger side of the draw, so at least half of the sequence survives\n",
    "        lambd = max(lambd, 1 - lambd)\n",
    "        win_len = int(round(seq_len * lambd))\n",
    "        if win_len == seq_len: return o\n",
    "        # randint excludes its upper bound: + 1 allows the kept window to end on the last step\n",
    "        start = np.random.randint(0, seq_len - win_len + 1)\n",
    "        # Everything outside the kept window stays at zero\n",
    "        output = torch.zeros_like(o)\n",
    "        output[..., start : start + win_len] = o[..., start : start + win_len]\n",
    "        if self.ex is not None: output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSRandomCropPad()(xb, split_idx=0).shape, xb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSMaskOut(RandTransform):\n",
    "    \"\"\"Applies a random mask\"\"\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.1, compensate:bool=False, ex=None, **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            magnitude: probability used to mask (zero) each individual value\n",
    "            compensate: when True, the masked output is divided by (1 - masked fraction) of each sequence\n",
    "            ex: index(es) on dim -2 whose original values are restored in the output\n",
    "        \"\"\"\n",
    "        store_attr()\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        mask = torch.rand_like(o) > (1 - self.magnitude)\n",
    "        output = o.masked_fill(mask, 0)\n",
    "        if self.compensate: # per sample and feature\n",
    "            # Masked count along the last dim, floored at 1, as a fraction of the sequence length\n",
    "            n_masked = torch.sum(mask, dim=-1).unsqueeze(-1)\n",
    "            floor = torch.ones(1, device=mask.device)\n",
    "            mean_per_seq = torch.max(floor, n_masked) / mask.shape[-1]\n",
    "            output = output / (1 - mean_per_seq)\n",
    "        if self.ex is not None:\n",
    "            output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSMaskOut()(xb, split_idx=0).shape, xb.shape)\n",
    "test_ne(TSMaskOut()(xb, split_idx=0), xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSInputDropout(RandTransform):\n",
    "    \"\"\"Applies dropout to the input with gradient tracking disabled\"\"\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0., ex=None, **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            magnitude: dropout probability passed to nn.Dropout; a falsy or non-positive value disables the transform\n",
    "            ex: index(es) on dim -2 whose original values are restored in the output\n",
    "        \"\"\"\n",
    "        self.magnitude = magnitude\n",
    "        self.ex = ex\n",
    "        self.dropout = nn.Dropout(magnitude)\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        dropped = self.dropout(o)\n",
    "        if self.ex is not None:\n",
    "            dropped[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return dropped"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSInputDropout(.1)(xb, split_idx=0).shape, xb.shape)\n",
    "test_ne(TSInputDropout(.1)(xb, split_idx=0), xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSTranslateX(RandTransform):\n",
    "    \"Moves a selected sequence window a random number of steps\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.1, ex=None, **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            magnitude: both shape parameters of the Beta distribution the shift (as a fraction of the length) is drawn from\n",
    "            ex: index(es) on dim -2 whose original values are restored in the output\n",
    "        \"\"\"\n",
    "        self.magnitude = magnitude\n",
    "        self.ex = ex\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        n_steps = o.shape[-1]\n",
    "        frac = np.random.beta(self.magnitude, self.magnitude)\n",
    "        shift = int(round(n_steps * min(frac, 1 - frac)))\n",
    "        if shift in (0, n_steps): return o\n",
    "        # Shift left or right with equal probability\n",
    "        if np.random.rand() < 0.5: shift = -shift\n",
    "        # Destination and source windows; the steps that are vacated stay at zero\n",
    "        dst = slice(max(0, shift), min(n_steps + shift, n_steps))\n",
    "        src = slice(max(0, -shift), min(n_steps - shift, n_steps))\n",
    "        shifted = torch.zeros_like(o)\n",
    "        shifted[..., dst] = o[..., src]\n",
    "        if self.ex is not None:\n",
    "            shifted[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return shifted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSTranslateX()(xb, split_idx=0).shape, xb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSRandomShift(RandTransform):\n",
    "    \"Shifts and splits a sequence\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.02, ex=None, **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            magnitude: scales a randomly drawn step index to obtain the size of the shift\n",
    "            ex: index(es) on dim -2 whose original values are restored in the output\n",
    "        \"\"\"\n",
    "        self.magnitude = magnitude\n",
    "        self.ex = ex\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        offset = int(round(np.random.randint(0, o.shape[-1]) * self.magnitude))\n",
    "        direction = random.randint(0, 1)*2-1  # -1 or +1\n",
    "        pos = offset * direction\n",
    "        # Split at `pos` and swap the two pieces, so the sequence wraps around\n",
    "        head, tail = o[..., :pos], o[..., pos:]\n",
    "        rolled = torch.cat((tail, head), dim=-1)\n",
    "        if self.ex is not None:\n",
    "            rolled[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return rolled"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSRandomShift()(xb, split_idx=0).shape, xb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSHorizontalFlip(RandTransform):\n",
    "    \"Flips the sequence along the x-axis\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=1., ex=None, **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            magnitude: on/off switch; a falsy or non-positive value makes the transform return its input unchanged\n",
    "            ex: index(es) on dim -2 whose original values are restored in the output\n",
    "        \"\"\"\n",
    "        self.magnitude = magnitude\n",
    "        self.ex = ex\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        # Reverse the order of the steps (last dimension)\n",
    "        flipped = torch.flip(o, dims=[-1])\n",
    "        if self.ex is not None:\n",
    "            flipped[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return flipped"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSHorizontalFlip()(xb, split_idx=0).shape, xb.shape)\n",
    "test_ne(TSHorizontalFlip()(xb, split_idx=0), xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSRandomTrend(RandTransform):\n",
    "    \"Randomly rotates the sequence along the z-axis\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.1, ex=None, **kwargs):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            magnitude: maximum absolute slope factor of the linear trend added to the sequence\n",
    "            ex: index(es) on dim -2 whose original values are restored in the output\n",
    "        \"\"\"\n",
    "        self.magnitude = magnitude\n",
    "        self.ex = ex\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        # Value range (max - min) of each item along dim 0, taken over all its remaining values\n",
    "        flat = o.reshape(o.shape[0], -1)\n",
    "        value_range = flat.max(dim=-1, keepdim=True)[0] - flat.min(dim=-1, keepdim=True)[0]\n",
    "        ramp = torch.linspace(0, 1, o.shape[-1], device=o.device) * value_range\n",
    "        # One random slope factor in [-magnitude, magnitude) shared by the whole batch\n",
    "        slope = self.magnitude * 2 * (np.random.rand() - 0.5)\n",
    "        t = 1 + slope * ramp\n",
    "        # Centre the trend so it has zero mean along the steps\n",
    "        t -= t.mean(-1, keepdim=True)\n",
    "        if o.ndim == 3:\n",
    "            t = t.unsqueeze(1)\n",
    "        trended = o + t\n",
    "        if self.ex is not None:\n",
    "            trended[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return trended\n",
    "\n",
    "TSRandomRotate = TSRandomTrend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSRandomTrend()(xb, split_idx=0).shape, xb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSVerticalFlip(RandTransform):\n",
    "    \"Applies a negative value to the time sequence\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=1., ex=None, **kwargs):\n",
    "        self.magnitude, self.ex = magnitude, ex\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        return - o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TSVerticalFlip()(xb, split_idx=0).shape, xb.shape)\n",
    "test_ne(TSVerticalFlip()(xb, split_idx=0), xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSResize(RandTransform):\n",
    "    \"Resizes the sequence length of a time series\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=-0.5, size=None, ex=None, mode='nearest', **kwargs):\n",
    "        \"mode:  'nearest' | 'linear' | 'area'\"\n",
    "\n",
    "        if not test_interpolate(mode):\n",
    "            print(f\"self.__name__ will not be applied because {mode} interpolation is not supported by {default_device()}. You can try a different mode\")\n",
    "            magnitude = 0\n",
    "\n",
    "        self.magnitude, self.size, self.ex, self.mode = magnitude, size, ex, mode\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if self.magnitude == 0: return o\n",
    "        size = ifnone(self.size, int(round((1 + self.magnitude) * o.shape[-1])))\n",
    "        output = F.interpolate(o, size=size, mode=self.mode, align_corners=None if self.mode in ['nearest', 'area'] else False)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if test_interpolate('nearest'):\n",
    "    for sz in np.linspace(.2, 2, 10): test_eq(TSResize(sz)(xb, split_idx=0).shape[-1], int(round(xb.shape[-1]*(1+sz))))\n",
    "    test_ne(TSResize(1)(xb, split_idx=0), xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSRandomSize(RandTransform):\n",
    "    \"Randomly resizes the sequence length of a time series\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.1, ex=None, mode='nearest', **kwargs):\n",
    "        \"mode:  'nearest' | 'linear' | 'area'\"\n",
    "\n",
    "        if not test_interpolate(mode):\n",
    "            print(f\"self.__name__ will not be applied because {mode} interpolation is not supported by {default_device()}. You can try a different mode\")\n",
    "            magnitude = 0\n",
    "\n",
    "        self.magnitude, self.ex, self.mode = magnitude, ex, mode\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        size_perc = 1 + random_half_normal() * self.magnitude * (-1 if random.random() > .5 else 1)\n",
    "        return F.interpolate(o, size=int(size_perc * o.shape[-1]), mode=self.mode, align_corners=None if self.mode in ['nearest', 'area'] else False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if test_interpolate('nearest'):\n",
    "    seq_len_ = []\n",
    "    for i in range(100):\n",
    "        o = TSRandomSize(.5)(xb, split_idx=0)\n",
    "        seq_len_.append(o.shape[-1])\n",
    "    test_lt(min(seq_len_), xb.shape[-1])\n",
    "    test_gt(max(seq_len_), xb.shape[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSRandomLowRes(RandTransform):\n",
    "    \"Randomly resizes the sequence length of a time series to a lower resolution\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=.5, ex=None, mode='nearest', **kwargs):\n",
    "        \"mode:  'nearest' | 'linear' | 'area'\"\n",
    "\n",
    "        if not test_interpolate(mode):\n",
    "            print(f\"self.__name__ will not be applied because {mode} interpolation is not supported by {default_device()}. You can try a different mode\")\n",
    "            magnitude = 0\n",
    "\n",
    "        self.magnitude, self.ex, self.mode = magnitude, ex, mode\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        size_perc = 1 - (np.random.rand() * (1 - self.magnitude))\n",
    "        return F.interpolate(o, size=int(size_perc * o.shape[-1]), mode=self.mode, align_corners=None if self.mode in ['nearest', 'area'] else False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSDownUpScale(RandTransform):\n",
    "    \"Downscales a time series and upscales it again to previous sequence length\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.5, ex=None, mode='nearest', **kwargs):\n",
    "        \"mode:  'nearest' | 'linear' | 'area'\"\n",
    "\n",
    "        if not test_interpolate(mode):\n",
    "            print(f\"self.__name__ will not be applied because {mode} interpolation is not supported by {default_device()}. You can try a different mode\")\n",
    "            magnitude = 0\n",
    "\n",
    "        self.magnitude, self.ex, self.mode = magnitude, ex, mode\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0 or self.magnitude >= 1: return o\n",
    "        output = F.interpolate(o, size=int((1 - self.magnitude) * o.shape[-1]), mode=self.mode, align_corners=None if self.mode in ['nearest', 'area'] else False)\n",
    "        output = F.interpolate(output, size=o.shape[-1], mode=self.mode, align_corners=None if self.mode in ['nearest', 'area'] else False)\n",
    "        if self.ex is not None: output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if test_interpolate('nearest'):\n",
    "    test_eq(TSDownUpScale()(xb, split_idx=0).shape, xb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSRandomDownUpScale(RandTransform):\n",
    "    \"Randomly downscales a time series and upscales it again to previous sequence length\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=.5, ex=None, mode='nearest', **kwargs):\n",
    "        \"mode:  'nearest' | 'linear' | 'area'\"\n",
    "\n",
    "        if not test_interpolate(mode):\n",
    "            print(f\"self.__name__ will not be applied because {mode} interpolation is not supported by {default_device()}. You can try a different mode\")\n",
    "            magnitude = 0\n",
    "\n",
    "        self.magnitude, self.ex, self.mode = magnitude, ex, mode\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0 or self.magnitude >= 1: return o\n",
    "        scale_factor = 0.5 + 0.5 * np.random.rand()\n",
    "        output = F.interpolate(o, size=int(scale_factor * o.shape[-1]), mode=self.mode, align_corners=None if self.mode in ['nearest', 'area'] else False)\n",
    "        output = F.interpolate(output, size=o.shape[-1], mode=self.mode, align_corners=None if self.mode in ['nearest', 'area'] else False)\n",
    "        if self.ex is not None: output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if test_interpolate('nearest'):\n",
    "    test_eq(TSRandomDownUpScale()(xb, split_idx=0).shape, xb.shape)\n",
    "    test_ne(TSDownUpScale()(xb, split_idx=0), xb)\n",
    "    test_eq(TSDownUpScale()(xb, split_idx=1), xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSRandomConv(RandTransform):\n",
    "    \"\"\"Applies a convolution with a random kernel and random weights with required_grad=False\"\"\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.05, ex=None, ks=[1, 3, 5, 7], **kwargs):\n",
    "        self.magnitude, self.ex, self.ks = magnitude, ex, ks\n",
    "        self.conv = nn.Conv1d(1, 1, 1, bias=False)\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0 or self.ks is None: return o\n",
    "        ks = random_choice(self.ks, 1)[0] if is_listy(self.ks) else self.ks\n",
    "        c_in = o.shape[1]\n",
    "        weight = nn.Parameter(torch.zeros(c_in, c_in, ks, device=o.device, requires_grad=False))\n",
    "        nn.init.kaiming_normal_(weight)\n",
    "        self.conv.weight = weight\n",
    "        self.conv.padding = ks // 2\n",
    "        output = (1 - self.magnitude) * o + self.magnitude * self.conv(o)\n",
    "        if self.ex is not None: output[...,self.ex,:] = o[...,self.ex,:]\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(5):\n",
    "    o = TSRandomConv(magnitude=0.05, ex=None, ks=[1, 3, 5, 7])(xb, split_idx=0)\n",
    "    test_eq(o.shape, xb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class TSRandom2Value(RandTransform):\n",
    "    \"Randomly sets selected variables of type `TSTensor` to predefined value (default: np.nan)\"\n",
    "    order = 90\n",
    "    def __init__(self, magnitude=0.1, sel_vars=None, sel_steps=None, static=False, value=np.nan, **kwargs):\n",
    "        assert not (sel_steps is not None and static), \"you must choose either static or sel_steps\"\n",
    "        if is_listy(sel_vars) and is_listy(sel_steps):\n",
    "            sel_vars = np.asarray(sel_vars)[:, None]\n",
    "        self.sel_vars, self.sel_steps = sel_vars, sel_steps\n",
    "        if sel_vars is None:\n",
    "            self._sel_vars = slice(None)\n",
    "        else:\n",
    "            self._sel_vars = sel_vars\n",
    "        if sel_steps is None or static:\n",
    "            self._sel_steps = slice(None)\n",
    "        else:\n",
    "            self._sel_steps = sel_steps\n",
    "        self.magnitude, self.static, self.value = magnitude , static, value\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def encodes(self, o:TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0 or self.magnitude > 1: return o\n",
    "        if self.static:\n",
    "            if self.sel_vars is not None:\n",
    "                if self.magnitude == 1:\n",
    "                    o[:, self._sel_vars] = o[:, self._sel_vars].fill_(self.value)\n",
    "                    return o\n",
    "                else:\n",
    "                    vals = torch.zeros_like(o)\n",
    "                    vals[:, self._sel_vars] = torch.rand(*vals[:, self._sel_vars, 0].shape, device=o.device).unsqueeze(-1)\n",
    "            else:\n",
    "                if self.magnitude == 1:\n",
    "                    return o.fill_(self.value)\n",
    "                else:\n",
    "                    vals = torch.rand(*o.shape[:-1], device=o.device).unsqueeze(-1)\n",
    "        elif self.sel_vars is not None or self.sel_steps is not None:\n",
    "            if self.magnitude == 1:\n",
    "                o[:, self._sel_vars, self._sel_steps] = o[:, self._sel_vars, self._sel_steps].fill_(self.value)\n",
    "                return o\n",
    "            else:\n",
    "                vals = torch.zeros_like(o)\n",
    "                vals[:, self._sel_vars, self._sel_steps] = torch.rand(*vals[:, self._sel_vars, self._sel_steps].shape, device=o.device)\n",
    "        else:\n",
    "            if self.magnitude == 1:\n",
    "                return o.fill_(self.value)\n",
    "            else:\n",
    "                vals = torch.rand_like(o)\n",
    "        mask = vals > (1 - self.magnitude)\n",
    "        return o.masked_fill(mask, self.value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0., 0., 1., 0., 1., 1., 0., 1., 1., 0.],\n",
       "         [1., 1., 0., 1., 1., 1., 1., 1., 1., 0.],\n",
       "         [1., 1., 1., 1., 1., 0., 0., 1., 1., 1.]],\n",
       "\n",
       "        [[1., 1., 1., 1., 1., 0., 1., 1., 0., 1.],\n",
       "         [0., 0., 0., 0., 0., 1., 0., 1., 0., 1.],\n",
       "         [0., 1., 0., 1., 0., 0., 0., 1., 0., 0.]]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = TSTensor(torch.ones(2, 3, 10))\n",
    "TSRandom2Value(magnitude=0.5, sel_vars=None, sel_steps=None, static=False, value=0)(t, split_idx=0).data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [1., 1., 1., 1., 1., 0., 1., 0., 0., 0.],\n",
       "         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],\n",
       "\n",
       "        [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [1., 1., 1., 1., 1., 0., 1., 0., 0., 0.],\n",
       "         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = TSTensor(torch.ones(2, 3, 10))\n",
    "TSRandom2Value(magnitude=0.5, sel_vars=[1], sel_steps=slice(-5, None), static=False, value=0)(t, split_idx=0).data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
       "         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],\n",
       "\n",
       "        [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = TSTensor(torch.ones(2, 3, 10))\n",
    "TSRandom2Value(magnitude=0.5, sel_vars=[1], sel_steps=None, static=True, value=0)(t, split_idx=0).data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
       "         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],\n",
       "\n",
       "        [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
       "         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = TSTensor(torch.ones(2, 3, 10))\n",
    "TSRandom2Value(magnitude=1, sel_vars=1, sel_steps=None, static=False, value=0)(t, split_idx=0).data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],\n",
       "\n",
       "        [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = TSTensor(torch.ones(2, 3, 10))\n",
    "TSRandom2Value(magnitude=1, sel_vars=[1,2], sel_steps=None, static=False, value=0)(t, split_idx=0).data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [1., 0., 1., 0., 1., 0., 1., 1., 1., 1.],\n",
       "         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],\n",
       "\n",
       "        [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [1., 0., 1., 0., 1., 0., 1., 1., 1., 1.],\n",
       "         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = TSTensor(torch.ones(2, 3, 10))\n",
    "TSRandom2Value(magnitude=1, sel_vars=1, sel_steps=[1,3,5], static=False, value=0)(t, split_idx=0).data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [1., 0., 1., 0., 1., 0., 1., 1., 1., 1.],\n",
       "         [1., 0., 1., 0., 1., 0., 1., 1., 1., 1.]],\n",
       "\n",
       "        [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [1., 0., 1., 0., 1., 0., 1., 1., 1., 1.],\n",
       "         [1., 0., 1., 0., 1., 0., 1., 1., 1., 1.]]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = TSTensor(torch.ones(2, 3, 10))\n",
    "TSRandom2Value(magnitude=1, sel_vars=[1,2], sel_steps=[1,3,5], static=False, value=0)(t, split_idx=0).data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1., nan, nan, 1.],\n",
       "         [1., 1., 1., 1.],\n",
       "         [1., nan, 1., 1.]],\n",
       "\n",
       "        [[nan, 1., 1., nan],\n",
       "         [1., 1., 1., 1.],\n",
       "         [nan, nan, 1., 1.]]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = TSTensor(torch.ones(2,3,4))\n",
    "TSRandom2Value(magnitude=.5, sel_vars=[0,2])(t, split_idx=0).data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1., 1., 1., nan],\n",
       "         [1., 1., nan, 1.],\n",
       "         [1., 1., nan, nan]],\n",
       "\n",
       "        [[1., 1., nan, 1.],\n",
       "         [1., 1., nan, nan],\n",
       "         [1., 1., nan, 1.]]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = TSTensor(torch.ones(2,3,4))\n",
    "TSRandom2Value(magnitude=.5, sel_steps=slice(2, None))(t, split_idx=0).data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = TSTensor(torch.ones(2,3,100))\n",
    "test_gt(np.isnan(TSRandom2Value(magnitude=.5)(t, split_idx=0)).sum().item(), 0)\n",
    "t = TSTensor(torch.ones(2,3,100))\n",
    "test_gt(np.isnan(TSRandom2Value(magnitude=.5, sel_vars=[0,2])(t, split_idx=0)[:, [0,2]]).sum().item(), 0)\n",
    "t = TSTensor(torch.ones(2,3,100))\n",
    "test_eq(np.isnan(TSRandom2Value(magnitude=.5, sel_vars=[0,2])(t, split_idx=0)[:, 1]).sum().item(), 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSMask2Value(RandTransform):\n",
    "    \"Randomly sets selected variables of type `TSTensor` to predefined value (default: np.nan)\"\n",
    "    order = 90\n",
    "    def __init__(self, mask_fn, value=np.nan, sel_vars=None, **kwargs):\n",
    "        self.sel_vars = sel_vars\n",
    "        self.mask_fn = mask_fn\n",
    "        self.value = value\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def encodes(self, o:TSTensor):\n",
    "        mask = self.mask_fn(o)\n",
    "        if self.sel_vars is not None:\n",
    "            mask[:, self.sel_vars] = False\n",
    "        return o.masked_fill(mask, self.value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = TSTensor(torch.ones(2,3,100))\n",
    "def _mask_fn(o, r=.15, value=np.nan):\n",
    "    return torch.rand_like(o) > (1-r)\n",
    "test_gt(np.isnan(TSMask2Value(_mask_fn)(t, split_idx=0)).sum().item(), 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def self_mask(o):\n",
    "    mask1 = torch.isnan(o)\n",
    "    mask2 = rotate_axis0(mask1)\n",
    "    return torch.logical_and(mask2, ~mask1)\n",
    "\n",
    "\n",
    "class TSSelfDropout(RandTransform):\n",
    "    \"\"\"Applies dropout to a tensor with nan values by rotating axis=0 inplace\"\"\"\n",
    "    order = 90\n",
    "    def encodes(self, o: TSTensor):\n",
    "        mask = self_mask(o)\n",
    "        o[mask] = np.nan\n",
    "        return o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.30000001192092896, 0.49000000953674316)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = TSTensor(torch.ones(2,3,100))\n",
    "mask = torch.rand_like(t) > .7\n",
    "t[mask] = np.nan\n",
    "nan_perc = np.isnan(t).float().mean().item()\n",
    "t2 = TSSelfDropout()(t, split_idx=0)\n",
    "test_gt(torch.isnan(t2).float().mean().item(), nan_perc)\n",
    "nan_perc, torch.isnan(t2).float().mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "all_TS_randaugs = [\n",
    "\n",
    "    TSIdentity,\n",
    "\n",
    "    # Noise\n",
    "    (TSMagAddNoise, 0.1, 1.),\n",
    "    (TSGaussianNoise, .01, 1.),\n",
    "    (partial(TSMagMulNoise, ex=0), 0.1, 1),\n",
    "    (partial(TSTimeNoise, ex=0), 0.1, 1.),\n",
    "    (partial(TSRandomFreqNoise, ex=0), 0.1, 1.),\n",
    "    partial(TSShuffleSteps, ex=0),\n",
    "    (TSRandomTimeScale, 0.05, 0.5),\n",
    "    (TSRandomTimeStep, 0.05, 0.5),\n",
    "    (partial(TSFreqDenoise, ex=0), 0.1, 1.),\n",
    "    (TSRandomLowRes, 0.05, 0.5),\n",
    "    (TSInputDropout, 0.05, .5),\n",
    "\n",
    "    # Magnitude\n",
    "    (partial(TSMagWarp, ex=0), 0.02, 0.2),\n",
    "    (TSMagScale, 0.2, 1.),\n",
    "    (partial(TSMagScalePerVar, ex=0), 0.2, 1.),\n",
    "    (partial(TSRandomConv, ex=0), .05, .2),\n",
    "    partial(TSBlur, ex=0),\n",
    "    partial(TSSmooth, ex=0),\n",
    "    partial(TSDownUpScale, ex=0),\n",
    "    partial(TSRandomDownUpScale, ex=0),\n",
    "    (TSRandomTrend, 0.1, 0.5),\n",
    "    TSVerticalFlip,\n",
    "    (TSVarOut, 0.05, 0.5),\n",
    "    (TSCutOut, 0.05, 0.5),\n",
    "\n",
    "    # Time\n",
    "    (partial(TSTimeWarp, ex=0), 0.02, 0.2),\n",
    "    (TSWindowWarp, 0.05, 0.5),\n",
    "    (TSRandomSize, 0.05, 1.),\n",
    "    TSHorizontalFlip,\n",
    "    (TSTranslateX, 0.1, 0.5),\n",
    "    (TSRandomShift, 0.02, 0.2),\n",
    "    (TSRandomZoomIn, 0.05, 0.5),\n",
    "    (TSWindowSlicing, 0.05, 0.2),\n",
    "    (TSRandomZoomOut, 0.05, 0.5),\n",
    "    (TSRandomLookBackOut, 0.1, 1.),\n",
    "    (TSRandomResizedLookBack, 0.1, 1.),\n",
    "    (TSTimeStepOut, 0.01, 0.2),\n",
    "    (TSRandomCropPad, 0.05, 0.5),\n",
    "    (TSRandomResizedCrop, 0.05, 0.5),\n",
    "    (TSMaskOut, 0.01, 0.2),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class RandAugment(RandTransform):\n",
    "    order = 90\n",
    "    def __init__(self, tfms:list, N:int=1, M:int=3, **kwargs):\n",
    "        '''\n",
    "        tfms   : list of tfm functions (not called)\n",
    "        N      : number of tfms applied to each batch (usual values 1-3)\n",
    "        M      : tfm magnitude multiplier (1-10, usually 3-5). Only works if tfms are tuples (tfm, min, max)\n",
    "        kwargs : RandTransform kwargs\n",
    "        '''\n",
    "        super().__init__(**kwargs)\n",
    "        if not isinstance(tfms, list): tfms = [tfms]\n",
    "        self.tfms, self.N, self.magnitude = tfms, min(len(tfms), N), M / 10\n",
    "        self.n_tfms, self.tfms_idxs = len(tfms), np.arange(len(tfms))\n",
    "\n",
    "    def encodes(self, o:(NumpyTensor, TSTensor)):\n",
    "        if not self.N or not self.magnitude: return o\n",
    "        tfms = self.tfms if self.n_tfms==1 else L(self.tfms)[random_choice(np.arange(self.n_tfms), self.N, replace=False)]\n",
    "        tfms_ = []\n",
    "        for tfm in tfms:\n",
    "            if isinstance(tfm, tuple):\n",
    "                t, min_val, max_val = tfm\n",
    "                tfms_ += [t(magnitude=self.magnitude * float(max_val - min_val) + min_val)]\n",
    "            else:  tfms_ += [tfm()]\n",
    "        output = compose_tfms(o, tfms_, split_idx=self.split_idx)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_ne(RandAugment(TSMagAddNoise, N=5, M=10)(xb, split_idx=0), xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TestTfm(RandTransform):\n",
    "    \"Utility class to test the output of selected tfms during training\"\n",
    "    def __init__(self, tfm, magnitude=1., ex=None, **kwargs):\n",
    "        self.tfm, self.magnitude, self.ex = tfm, magnitude, ex\n",
    "        self.tfmd, self.shape = [], []\n",
    "        super().__init__(**kwargs)\n",
    "    def encodes(self, o: TSTensor):\n",
    "        if not self.magnitude or self.magnitude <= 0: return o\n",
    "        output = self.tfm(o, split_idx=self.split_idx)\n",
    "        self.tfmd.append(torch.equal(o, output))\n",
    "        self.shape.append(o.shape)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def get_tfm_name(tfm):\n",
    "    if isinstance(tfm, tuple): tfm = tfm[0]\n",
    "    if hasattr(tfm, \"func\"): tfm = tfm.func\n",
    "    if hasattr(tfm, \"__name__\"): return tfm.__name__\n",
    "    elif hasattr(tfm, \"__class__\") and hasattr(tfm.__class__, \"__name__\"): return tfm.__class__.__name__\n",
    "    else: return tfm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(get_tfm_name(partial(TSMagScale()))==get_tfm_name((partial(TSMagScale()), 0.1, .05))==get_tfm_name(TSMagScale())==get_tfm_name((TSMagScale(), 0.1, .05)), True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_TS_randaugs_names = [get_tfm_name(t) for t in all_TS_randaugs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": "IPython.notebook.save_checkpoint();",
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/nacho/notebooks/tsai/nbs/010_data.transforms.ipynb saved at 2024-02-10 21:46:00\n",
      "Correct notebook to script conversion! 😃\n",
      "Saturday 10/02/24 21:46:03 CET\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcO
wn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "from tsai.export import get_nb_name; nb_name = get_nb_name(locals())\n",
    "from tsai.imports import create_scripts; create_scripts(nb_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
